import torch
import os, wandb
from network.base_net import RNN_quick, RNN_preAC, RNN
from network.qmix_net import QMixNet
import wandb, csv, random
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from torch.nn import functional as F


def setup_seed(seed, num_threads=8):
    """Seed the random number generators and configure PyTorch for repeatable runs.

    Args:
        seed: Integer seed applied to Python's ``random`` module and to the
            PyTorch CPU and CUDA generators.
        num_threads: Number of CPU threads passed to ``torch.set_num_threads``.
            Defaults to 8, the value that used to be hard-coded.
    """
    # This file samples with ``random.random()``, so seed the stdlib generator too.
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN autotune and possibly select different kernels
    # from run to run, which works against the deterministic flag set above.
    torch.backends.cudnn.benchmark = False
    torch.set_num_threads(num_threads)


def build_td_lambda_targets(rewards, terminated, mask, target_qs, n_agents, gamma, td_lambda):
    """Compute TD(lambda) targets with a backward recursion over the time axis.

    ``target_qs`` is indexed as (batch, time, ...); ``rewards``, ``terminated``
    and ``mask`` are indexed on the same two leading axes and need at least
    one time step fewer than ``target_qs``.

    Args:
        rewards: Per-step rewards.
        terminated: Per-step termination flags (1 where the episode ended).
        mask: Per-step validity flags (0 for padded steps).
        target_qs: Target values, one more time step than the returned targets.
        n_agents: Unused; kept for interface compatibility.
        gamma: Discount factor.
        td_lambda: Lambda mixing coefficient.

    Returns:
        The lambda-returns for every time step except the last one.
    """
    lambda_returns = torch.zeros_like(target_qs)
    # Bootstrap from the final step only for episodes that never terminated.
    lambda_returns[:, -1] = target_qs[:, -1] * (1 - terminated.sum(dim=1))

    carry_weight = td_lambda * gamma
    one_step_weight = (1 - td_lambda) * gamma
    n_steps = lambda_returns.shape[1]
    # Walk backwards so that step t can reuse the return already stored at t + 1.
    for t in reversed(range(n_steps - 1)):
        bootstrap = one_step_weight * target_qs[:, t + 1] * (1 - terminated[:, t])
        lambda_returns[:, t] = carry_weight * lambda_returns[:, t + 1] \
            + mask[:, t] * (rewards[:, t] + bootstrap)
    return lambda_returns[:, :-1]


class QMIX:
    """QMIX learner: per-agent recurrent Q-networks combined by a mixing network.

    Besides the usual TD loss, the learner adds a prediction loss returned by
    the agent network and augments the environment reward (in the TD(lambda)
    branch) with intrinsic rewards obtained from ``target_rnn.get_intrinsic``.
    """

    def __init__(self, args):
        """Build the eval/target networks, optionally load a checkpoint, create the optimizer.

        Args:
            args: Configuration namespace. The attributes read here include
                ``n_actions``, ``n_agents``, ``state_shape``, ``obs_shape``,
                ``last_action``, ``reuse_network``, ``seed``, ``p_state``,
                ``visual_r``, ``loc_dim``, ``uRNN``, ``cuda``, ``GPU``,
                ``model_dir``, ``env``, ``label``, ``load_model``, ``step``,
                ``optimizer`` and ``lr``.

        Raises:
            Exception: If ``args.load_model`` is set but no agent-network
                checkpoint exists for ``args.step``.
        """
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.state_shape = args.state_shape
        self.obs_shape = args.obs_shape
        input_shape = self.obs_shape
        # The RNN input width depends on which extra features are appended.
        if args.last_action:
            input_shape += self.n_actions
        if args.reuse_network:
            input_shape += self.n_agents
        # Seed before the networks are created so their initialisation is repeatable.
        setup_seed(args.seed)

        self.p_state = args.p_state
        self.visual_r = args.visual_r
        self.p_s_dim = args.loc_dim

        # Agent networks (one shared network selects actions for every agent).
        if args.uRNN:
            self.eval_rnn = RNN_preAC(input_shape, args)
            self.target_rnn = RNN_preAC(input_shape, args)
        else:
            self.eval_rnn = RNN_quick(input_shape, args)
            self.target_rnn = RNN_quick(input_shape, args)
        self.action_list = torch.linspace(.1, 1, self.n_actions).reshape(1, -1)

        # Mixing networks that combine the per-agent Q-values.
        self.eval_qmix_net = QMixNet(args)
        self.target_qmix_net = QMixNet(args)
        self.args = args
        if self.args.cuda:
            device = torch.device(self.args.GPU)
            self.eval_rnn.to(device)
            self.target_rnn.to(device)
            self.eval_qmix_net.to(device)
            self.target_qmix_net.to(device)
        self.model_dir = args.model_dir + f'/{args.env}/{args.label}/{args.seed}'

        # Load an existing checkpoint when requested.
        if self.args.load_model:
            path_rnn = self.model_dir + f'/{args.step}_rnn_net_params.pkl'
            path_qmix = self.model_dir + f'/{args.step}_qmix_net_params.pkl'
            if not os.path.exists(path_rnn):
                raise Exception("No model!")
            # Load straight onto the device the networks live on instead of
            # always going through 'cuda:0'.
            map_location = torch.device(self.args.GPU) if self.args.cuda else 'cpu'
            self.eval_rnn.load_state_dict(torch.load(path_rnn, map_location=map_location))
            self.eval_qmix_net.load_state_dict(torch.load(path_qmix, map_location=map_location))
            print('Successfully load the model: {} and {}'.format(path_rnn, path_qmix))

        # TD(lambda) targets are used only for environments whose name contains 'vs'.
        self.td_label = 'vs' in args.env
        self.csv_dir = f'./csv_file/{args.env}/reward/{args.label}'
        self.csv_path = f'{self.csv_dir}/seed_{args.seed}_{args.label}.csv'
        os.makedirs(self.csv_dir, exist_ok=True)

        # Start the target networks from the same weights as the eval networks.
        self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
        self.target_qmix_net.load_state_dict(self.eval_qmix_net.state_dict())

        self.eval_parameters = list(self.eval_qmix_net.parameters()) + list(self.eval_rnn.parameters())
        if args.optimizer == "RMS":
            self.optimizer = torch.optim.RMSprop(self.eval_parameters, lr=args.lr)
        else:
            # Construction still succeeds (e.g. for evaluation only); learn()
            # reports the unsupported optimizer explicitly.
            self.optimizer = None

        # During learning one eval/target hidden state is kept per agent and
        # per episode of the batch; see init_hidden().
        self.eval_hidden = None
        self.target_hidden = None
        print('Init alg QMIX')

    def _intrinsic_anneal_factor(self, t_env):
        """Return the multiplier applied to the intrinsic rewards at step ``t_env``.

        The factor is 1.0 until ``t_env`` exceeds ``args.start_anneal_time``
        and for unknown ``args.anneal_type`` values. The 'linear' schedule
        decays to 0; the 'exp' schedule decays exponentially and is clamped
        to the range [0.01, 1].
        """
        import math  # local import: this module does not import math or numpy

        if t_env <= self.args.start_anneal_time:
            return 1.0
        if self.args.anneal_type == 'linear':
            return max(1 - self.args.anneal_rate * (
                    t_env - self.args.start_anneal_time) / 1000000, 0)
        if self.args.anneal_type == 'exp':
            exp_scaling = (-1) * (1 / self.args.anneal_rate) / math.log(0.01)
            elapsed = (t_env - self.args.start_anneal_time) / 1000000
            return min(1, max(0.01, math.exp(-elapsed / exp_scaling)))
        return 1.0

    def learn(self, batch, max_episode_len, train_step, epsilon=None,
              t_env=0):
        """Run one optimisation step on a batch of whole episodes.

        The sampled data is four-dimensional: (episode, transition within the
        episode, agent, feature). Because the agent network is recurrent and
        its hidden state depends on earlier steps, whole episodes are sampled
        rather than independent transitions.

        Args:
            batch: Dict of array-likes. Every value is converted in place to a
                tensor (``long`` for key ``'u'``, ``float32`` otherwise).
            max_episode_len: Forwarded to ``get_q_values``.
            train_step: Index of this learning step; controls when the target
                networks are synchronised.
            epsilon: Unused.
            t_env: Environment step count, used to anneal intrinsic rewards
                and as the logged step.

        Raises:
            RuntimeError: If no optimizer was created in ``__init__``.
        """
        if getattr(self, 'optimizer', None) is None:
            raise RuntimeError(
                f"learn() needs an optimizer, but args.optimizer={self.args.optimizer!r} "
                f"is not supported (only 'RMS')")

        episode_num = batch['o'].shape[0]
        self.init_hidden(episode_num)
        # Convert every entry of the batch to a tensor.
        for key in batch.keys():
            dtype = torch.long if key == 'u' else torch.float32
            batch[key] = torch.tensor(batch[key], dtype=dtype)
        s, s_next, u, r = batch['s'], batch['s_next'], batch['u'], batch['r']
        avail_u_next, terminated = batch['avail_u_next'], batch['terminated']
        # Zero for padded transitions so they do not contribute to the loss.
        mask = 1 - batch["padded"].float()

        # Per-agent Q-values indexed as (episode, step, agent, action).
        (q_evals, q_targets, rnn_inputs, rnn_outputs, intrinsic_1, intrinsic_2,
         intrinsic_rewards, mask_cal, pre_loss) = self.get_q_values(batch, max_episode_len)

        intrinsic_rewards = self._intrinsic_anneal_factor(t_env) * intrinsic_rewards

        if self.args.cuda:
            device = torch.device(self.args.GPU)
            s = s.to(device)
            u = u.to(device)
            r = r.to(device)
            s_next = s_next.to(device)
            terminated = terminated.to(device)
            mask = mask.to(device)
        # Keep the availability mask on the same device as the tensor it indexes.
        avail_u_next = avail_u_next.to(q_targets.device)

        # Q-value of the action actually taken; drop the now size-1 action axis.
        q_evals = torch.gather(q_evals, dim=3, index=u).squeeze(3)

        # Target Q: best available next action.
        q_targets[avail_u_next == 0.0] = - 9999999
        q_targets = q_targets.max(dim=3)[0]

        q_total_eval = self.eval_qmix_net(q_evals, s)
        q_total_target = self.target_qmix_net(q_targets, s_next)
        rewards = r + self.args.beta * intrinsic_rewards.mean(dim=-2)
        if self.td_label:
            # Prepend a copy of the first step so the target sequence has one
            # more step than the rewards, as build_td_lambda_targets expects.
            q_total_target = torch.cat((q_total_target[:, 0, ...].unsqueeze(1), q_total_target), dim=1)
            targets = build_td_lambda_targets(rewards, terminated, mask, q_total_target,
                                              self.args.n_agents, self.args.gamma, 0.6)
        else:
            # NOTE(review): this branch uses the raw reward ``r`` and ignores the
            # intrinsic-augmented ``rewards`` computed above -- confirm this is intended.
            targets = r + self.args.gamma * q_total_target * (1 - terminated)

        td_error = (q_total_eval - targets.detach())
        mask_pre_error = pre_loss * mask
        mask_pre_loss = mask_pre_error.sum() / mask.sum()
        masked_td_error = mask * td_error  # zero out the TD error of padded steps
        # Divide by the number of real steps instead of taking a plain mean,
        # because padded entries must not count.
        loss = (masked_td_error ** 2).sum() / mask.sum() + mask_pre_loss
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.args.grad_norm_clip)
        self.optimizer.step()

        # Log statistics for roughly 0.5% of the learning steps.
        if random.random() < 0.005:
            mask = mask.unsqueeze(-2).repeat(1, 1, self.n_agents, 1)
            self.writereward(self.csv_path, mask, self.args.beta * intrinsic_1 * mask,
                             self.args.beta * intrinsic_2 * mask,
                             self.args.beta * intrinsic_rewards * mask,
                             r * mask[..., 0], masked_td_error, loss, t_env, mask_pre_loss)

        if train_step > 0 and train_step % self.args.target_update_cycle == 0:
            self.target_rnn.load_state_dict(self.eval_rnn.state_dict())
            self.target_qmix_net.load_state_dict(self.eval_qmix_net.state_dict())

    def _get_inputs_matrix(self, batch):
        """Assemble the agent-network inputs for every step of every episode at once.

        Returns:
            Tuple ``(inputs, inputs_next, obs_clone, obs_next_clone, add_id,
            vae_inputs, u_onehot)`` where ``inputs``/``inputs_next`` are the
            observations with the optional previous-action and agent-id
            features appended on the last axis.
        """
        inputs, inputs_next = batch['o'], batch['o_next']
        u_onehot = batch['u_onehot']
        obs_clone = inputs.clone()
        obs_next_clone = inputs_next.clone()

        if self.args.last_action:
            # Shift the actions one step along the time axis so that step t
            # sees the action of step t-1; step 0 sees an all-zero vector.
            u_onehot_f = torch.zeros_like(u_onehot)
            u_onehot_f[:, 1:, :, :] = u_onehot[:, :-1, :, :]

            inputs = torch.cat([inputs, u_onehot_f], dim=-1)
            inputs_next = torch.cat([inputs_next, u_onehot], dim=-1)
        vae_inputs = torch.cat([obs_clone, u_onehot], dim=-1)
        # One-hot agent ids, broadcast over the two leading axes.
        add_id = torch.eye(self.args.n_agents).type_as(inputs).expand(
            [inputs.shape[0], inputs.shape[1], -1, -1])

        if self.args.reuse_network:
            inputs = torch.cat([inputs, add_id], dim=-1)
            inputs_next = torch.cat([inputs_next, add_id], dim=-1)

        return inputs, inputs_next, obs_clone, obs_next_clone, add_id, vae_inputs, u_onehot

    def _get_inputs(self, batch, transition_idx):
        """Assemble the agent-network inputs for a single transition index.

        Returns:
            ``(inputs, inputs_next)``, each flattened to
            ``(episode_num * n_agents, feature_dim)``.
        """
        # All of u_onehot is kept because the previous step's action is needed.
        obs, obs_next, u_onehot = batch['o'][:, transition_idx], \
                                  batch['o_next'][:, transition_idx], batch['u_onehot'][:]
        episode_num = obs.shape[0]
        inputs, inputs_next = [obs], [obs_next]

        # Append the previous action and the agent id to the observation.
        if self.args.last_action:
            if transition_idx == 0:  # first transition: previous action is all zeros
                inputs.append(torch.zeros_like(u_onehot[:, transition_idx]))
            else:
                inputs.append(u_onehot[:, transition_idx - 1])
            inputs_next.append(u_onehot[:, transition_idx])
        if self.args.reuse_network:
            # obs is indexed as (episode, agent, feature), so the one-hot ids of
            # all agents form an identity matrix repeated for every episode.
            # Build it on the observation's device so the concatenation below
            # also works for non-CPU batches.
            agent_ids = torch.eye(self.args.n_agents, device=obs.device).unsqueeze(0).expand(
                episode_num, -1, -1)
            inputs.append(agent_ids)
            inputs_next.append(agent_ids)
        # Concatenate the features and flatten episodes and agents into one
        # axis: all agents share one network and each row carries its own id.
        inputs = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs], dim=1)
        inputs_next = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs_next], dim=1)
        return inputs, inputs_next

    def get_q_values(self, batch, max_episode_len):
        """Run the eval/target agent networks and compute the auxiliary quantities.

        Args:
            batch: Dict of tensors; reads ``'o'``, ``'o_next'``, ``'u_onehot'``,
                ``'padded'`` and ``'p_state'``.
            max_episode_len: Unused.

        Returns:
            Tuple ``(q_evals, q_targets, rnn_inputs, rnn_outputs, intrinsic_1,
            intrinsic_2, intrinsic_rewards, mask_cal, pre_loss)``.
        """
        inputs, inputs_next, obs, _, _, _, _ = self._get_inputs_matrix(batch)
        episode_num = batch['o'].shape[0]
        mask_cal = 1 - batch["padded"].float()
        p_state = batch['p_state']
        if self.args.cuda:
            device = torch.device(self.args.GPU)
            self.eval_rnn.to(device)
            self.target_rnn.to(device)
            inputs = inputs.to(device)
            inputs_next = inputs_next.to(device)
            obs = obs.to(device)
            mask_cal = mask_cal.to(device)
            p_state = p_state.to(device)

            self.eval_hidden = self.eval_hidden.to(device)
            self.target_hidden = self.target_hidden.to(device)

        # Prediction targets are built from p_state shifted ``shift_t`` steps
        # into the future; the trailing steps repeat the last available state.
        shift_t = 2
        p_dim0, p_dim1, _, _ = p_state.shape
        # Boolean selector that drops the (i, i) agent pairs, i.e. keeps "other agent" entries.
        p_index = ((1 - torch.eye(self.n_agents)) == 1).reshape(1, 1, -1, 1).repeat(p_dim0, p_dim1,
                                                                                    1, self.p_state).to(p_state.device)
        p_state_loc = (p_state[..., :self.p_s_dim]).repeat(1, 1, 1, self.n_agents).reshape(p_state.shape[0],
                                                                                      p_state.shape[1], -1, self.p_s_dim)
        p_t1_state = p_state.clone().roll(dims=1, shifts=-shift_t)
        p_t1_state[:, -shift_t:] = (p_state[:, -1]).unsqueeze(1).clone()
        p_state_neg_loc = (p_t1_state[..., :self.p_s_dim]).repeat(1, 1, self.n_agents, 1)
        # Future location entries minus current location entries for every agent pair.
        rel_p_state = p_state_neg_loc - p_state_loc
        # Remaining (non-location) features of the shifted state, repeated per agent.
        p_ball = (p_t1_state[..., self.p_s_dim:]).repeat(1, 1, self.n_agents, 1)
        p_s = torch.cat((rel_p_state, p_ball), dim=-1)

        # Flatten (episode, agent) into the batch axis expected by the agent network.
        eval_h = self.eval_hidden.view(-1, self.args.rnn_hidden_dim)
        target_h = self.target_hidden.view(-1, self.args.rnn_hidden_dim)
        inputs = inputs.permute(0, 2, 1, 3)
        inputs_next = inputs_next.permute(0, 2, 1, 3)

        inputs = inputs.reshape(-1, inputs.shape[2], inputs.shape[3])
        inputs_next = inputs_next.reshape(-1,
                                          inputs_next.shape[2], inputs_next.shape[3])
        p_s_copy = p_s[p_index].reshape(p_s.shape[0], p_s.shape[1], -1, (self.n_agents - 1) * self.p_state)
        q_evals, out_eval_h, pre_state, ally_mask = self.eval_rnn(inputs, eval_h)
        q_targets, _, _, _ = self.target_rnn(inputs_next, target_h)
        # Back to (episode, step, agent, feature).
        pre_state = pre_state.reshape(-1, self.n_agents, pre_state.shape[1], pre_state.shape[2]).permute(0, 2, 1, 3)
        pre_loss_mask = ally_mask.reshape(-1, self.n_agents, ally_mask.shape[1], ally_mask.shape[2]).permute(0, 2, 1, 3)
        # Masked squared prediction error, summed over features and agents.
        pre_loss = (F.mse_loss(pre_state, p_s_copy, reduction='none') * pre_loss_mask).sum(dim=-1).sum(dim=-1, keepdim=True)

        with torch.no_grad():
            # Replace the "other agent" entries of the targets by the network's predictions.
            p_s_intrin = p_s.clone()
            p_s_intrin[p_index] = pre_state.reshape(-1)
            pi_random_diverge, pi_diverge = self.target_rnn.get_intrinsic(inputs.clone(), p_s_intrin, p_index.clone(),
                                                                          target_h,
                                                                          q_evals.clone(), pre_loss_mask)
            mask_cal = mask_cal.unsqueeze(-2).expand(obs.shape[:-1] + mask_cal.shape[-1:])
            mask_cal = mask_cal.reshape(-1, mask_cal.shape[-1])
            dim0_batches, dim1_steps, dim2_agents, _ = obs.shape
            rnn_inputs = torch.cat((out_eval_h, inputs), dim=-1)
            rnn_inputs = rnn_inputs.reshape(dim0_batches, dim2_agents, dim1_steps, rnn_inputs.shape[-1]).permute(0, 2,
                                                                                                                 1, 3)

            rnn_inputs = rnn_inputs.reshape(-1, rnn_inputs.shape[-1])
            rnn_outputs = (p_s[p_index]).reshape(-1, self.p_state * (self.n_agents - 1))
            intrinsic_1 = self.args.beta1 * pi_random_diverge
            intrinsic_2 = self.args.beta2 * pi_diverge
            intrinsic_2 = intrinsic_2.clamp(min=0, max=0.5)
            intrinsic_rewards = intrinsic_1 + intrinsic_2

        # Back to (episode, step, agent, action).
        q_evals = q_evals.reshape(episode_num, -1, q_evals.shape[-2], q_evals.shape[-1]).permute(0, 2, 1, 3)
        q_targets = q_targets.reshape(episode_num, -1, q_targets.shape[-2], q_targets.shape[-1]).permute(0, 2, 1, 3)
        return q_evals, q_targets, rnn_inputs, rnn_outputs, intrinsic_1, intrinsic_2, intrinsic_rewards, mask_cal, pre_loss

    def init_hidden(self, episode_num):
        """Reset the eval/target hidden states to zeros for every agent of every episode."""
        self.eval_hidden = torch.zeros((episode_num, self.n_agents, self.args.rnn_hidden_dim))
        self.target_hidden = torch.zeros((episode_num, self.n_agents, self.args.rnn_hidden_dim))

    def save_model(self, train_step):
        """Save the eval mixing and agent networks under ``self.model_dir``.

        The file prefix is ``train_step // args.save_cycle``.
        """
        num = str(train_step // self.args.save_cycle)
        os.makedirs(self.model_dir, exist_ok=True)
        torch.save(self.eval_qmix_net.state_dict(), self.model_dir + '/' + num + '_qmix_net_params.pkl')
        torch.save(self.eval_rnn.state_dict(), self.model_dir + '/' + num + '_rnn_net_params.pkl')

    def build_q_lambda_targets(self, rewards, terminated, mask, exp_qvals, qvals, gamma, td_lambda):
        """Compute Q(lambda)-style targets with a backward recursion over the time axis.

        ``exp_qvals`` and ``qvals`` are indexed as (batch, time, ...);
        ``rewards``, ``terminated`` and ``mask`` need at least one time step
        fewer. The reward of each step is corrected by
        ``exp_qvals[:, t] - qvals[:, t]`` before it enters the recursion.

        Returns:
            The targets for every time step except the last one.
        """
        ret = exp_qvals.new_zeros(*exp_qvals.shape)
        # Bootstrap from the final step only for episodes that never terminated.
        ret[:, -1] = exp_qvals[:, -1] * (1 - torch.sum(terminated, dim=1))
        # Backward recursive update of the "forward view".
        for t in range(ret.shape[1] - 2, -1, -1):
            reward = rewards[:, t] + exp_qvals[:, t] - qvals[:, t]  # off-policy correction
            ret[:, t] = td_lambda * gamma * ret[:, t + 1] + mask[:, t] \
                        * (reward + (1 - td_lambda) * gamma * exp_qvals[:, t + 1] * (1 - terminated[:, t]))
        return ret[:, 0:-1]

    def writereward(self, path, mask, intrin_1, intrin_2, intrin_sum, reward, td_error, loss, step, pre_loss):
        """Log reward/loss statistics to wandb (if enabled) and append them to a CSV file.

        A header row is written first when ``path`` does not exist yet.

        Args:
            path: CSV file to append to.
            mask: Validity mask; its sum is the divisor of the ``*_mean`` statistics.
            intrin_1, intrin_2, intrin_sum: Already-masked intrinsic reward terms.
            reward: Already-masked environment reward.
            td_error: Masked TD error tensor.
            loss: Scalar loss tensor.
            step: Step value written to the first column.
            pre_loss: Prediction loss (logged to wandb only).
        """
        mask_elems = mask.sum()
        intrin_reward_mean_1 = intrin_1.sum() / mask_elems
        intrinsic_rewards_mean = intrin_sum.sum() / mask_elems
        intrinsic_rewards_sum = intrin_sum.squeeze().sum(axis=-2).reshape(-1).mean()

        intrin_reward_sum_1 = intrin_1.squeeze().sum(axis=-2).reshape(-1).mean()
        intrin_reward_mean_2 = intrin_2.sum() / mask_elems
        intrin_reward_sum_2 = intrin_2.squeeze().sum(axis=-2).reshape(-1).mean()
        reward = reward.squeeze().sum(axis=-1).reshape(-1).mean()
        if self.args.wandb:
            wandb.log(
                {'step': step, "intrin_reward_mean_1": intrin_reward_mean_1,
                 "intrin_reward_sum_1": intrin_reward_sum_1, 'pre_loss': pre_loss,
                 'intrinsic_rewards_mean': intrinsic_rewards_mean, 'intrinsic_rewards_sum': intrinsic_rewards_sum})
            wandb.log(
                {"intrinsic_rewards_mean_2": intrin_reward_mean_2, "intrinsic_rewards_sum_2 ": intrin_reward_sum_2})
            wandb.log({" td_error": td_error, " Training reward": reward, " loss": loss})

        row = [step, intrin_reward_mean_1.item(), intrin_reward_sum_1.item(), intrin_reward_mean_2.item(),
               intrin_reward_sum_2.item(), reward.item(),
               reward.item() + intrin_reward_sum_1.item() + intrin_reward_sum_2.item(),
               td_error.mean().item(), loss.item()]
        needs_header = not os.path.isfile(path)
        # newline='' lets the csv module control line endings (avoids blank
        # rows on platforms that translate '\n').
        with open(path, 'a', newline='') as f:
            csv_write = csv.writer(f)
            if needs_header:
                # NOTE(review): columns 4-5 hold the second intrinsic term even
                # though they are labelled 'intrinsic_rewards_*'; the labels are
                # kept unchanged for compatibility with existing files.
                csv_write.writerow(
                    ['step', 'intrinsic1_mean', 'intrinsic1_sum', 'intrinsic_rewards_mean', 'intrinsic_rewards_sum',
                     'reward', 'sum',
                     'td_error', 'loss'])
            csv_write.writerow(row)
